/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/* \file
   \brief Execution environment
*/

#pragma once

#include <stdexcept>
#include <list>
#include <vector>

#include "cutlass/library/library.h"
#include "cutlass/util/distribution.h"

#include "enumerated_types.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace profiler {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Device memory allocation
class DeviceAllocation {
private:
    /// Data type of contained elements
    library::NumericTypeID type_;

    /// Gets the stride between elements
    size_t batch_stride_;

    /// Capacity in elements of device allocation
    size_t capacity_;

    /// Pointer to device memory
    void* pointer_;

    /// Layout type ID
    library::LayoutTypeID layout_;

    /// Stride vector
    std::vector<int> stride_;

    /// Extent vector
    std::vector<int> extent_;

    /// Support allocating a 'batch' of non-overlapping tensors in contiguous
    /// memory
    int batch_count_;

    /// Buffer holding TensorRef instance to recently allocated memory
    std::vector<uint8_t> tensor_ref_buffer_;

public:
    //
    // Static member functions
    //

    /// Determines the number of bytes needed to represent this numeric type
    static size_t bytes(library::NumericTypeID type, size_t capacity);

    /// Returns the stride of a packed layout
    static std::vector<int> get_packed_layout(library::LayoutTypeID layout_id,
                                              std::vector<int> const& extent);

    /// returns the capacity needed
    static size_t construct_layout(void* bytes, library::LayoutTypeID layout_id,
                                   std::vector<int> const& extent,
                                   std::vector<int>& stride);

    /// Returns true if two blocks have exactly the same value
    static bool block_compare_equal(library::NumericTypeID numeric_type,
                                    void const* ptr_A, void const* ptr_B,
                                    size_t capacity);

    /// Returns true if two blocks have approximately the same value
    static bool block_compare_relatively_equal(
            library::NumericTypeID numeric_type, void const* ptr_A,
            void const* ptr_B, size_t capacity, double epsilon,
            double nonzero_floor);

public:
    //
    // Methods
    //

    DeviceAllocation();

    DeviceAllocation(library::NumericTypeID type, size_t capacity);

    DeviceAllocation(library::NumericTypeID type,
                     library::LayoutTypeID layout_id,
                     std::vector<int> const& extent,
                     std::vector<int> const& stride = std::vector<int>(),
                     int batch_count = 1);

    ~DeviceAllocation();

    DeviceAllocation& reset();

    /// Allocates device memory of a given type and capacity
    DeviceAllocation& reset(library::NumericTypeID type, size_t capacity);

    /// Allocates memory for a given layout and tensor
    DeviceAllocation& reset(library::NumericTypeID type,
                            library::LayoutTypeID layout_id,
                            std::vector<int> const& extent,
                            std::vector<int> const& stride = std::vector<int>(),
                            int batch_count = 1);

    /// Returns a buffer owning the tensor reference
    std::vector<uint8_t>& tensor_ref() { return tensor_ref_buffer_; }

    bool good() const;

    /// Data type of contained elements
    library::NumericTypeID type() const;

    /// Pointer to start of device memory allocation
    void* data() const;

    /// Pointer to the first element of a batch
    void* batch_data(int batch_idx) const;

    /// Gets the layout type
    library::LayoutTypeID layout() const;

    /// Gets the stride vector
    std::vector<int> const& stride() const;

    /// Gets the extent vector
    std::vector<int> const& extent() const;

    /// Gets the number of adjacent tensors in memory
    int batch_count() const;

    /// Gets the stride (in units of elements) beteween items
    int64_t batch_stride() const;

    /// Gets the stride (in units of bytes) beteween items
    int64_t batch_stride_bytes() const;

    /// Capacity of allocation in number of elements
    size_t capacity() const;

    /// Capacity of allocation in bytes
    size_t bytes() const;

    /// Initializes a device allocation to a random distribution using cuRAND
    void initialize_random_device(int seed, Distribution dist);

    /// Initializes a host allocation to a random distribution using std::cout
    void initialize_random_host(int seed, Distribution dist);

    /// Initializes a device allocation to a random distribution using cuRAND
    void initialize_random_sparsemeta_device(int seed, int MetaSizeInBits);

    /// Initializes a host allocation to a random distribution using std::cout
    void initialize_random_sparsemeta_host(int seed, int MetaSizeInBits);

    /// Copies from an equivalent-sized tensor in device memory
    void copy_from_device(void const* ptr);

    /// Copies from an equivalent-sized tensor in device memory
    void copy_from_host(void const* ptr);

    /// Copies from an equivalent-sized tensor in device memory
    void copy_to_host(void* ptr);

    /// Writes a tensor to csv
    void write_tensor_csv(std::ostream& out);
};

using DeviceAllocationList = std::list<DeviceAllocation>;

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace profiler
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
